{
 "cells": [
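  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Softmax and classification models\n",
    "This notebook covers:  \n",
    "1. The concept of softmax regression\n",
    "2. Loading the Fashion-MNIST dataset\n",
    "3. Implementing a softmax regression model to classify image data"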
   ]
  },
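  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The softmax concept\n",
    "- Classification problem  \n",
    "Consider a simple image classification problem in which each input image is 2 pixels high and 2 pixels wide, in grayscale.  \n",
    "Denote the 4 pixels of an image by $x_1,x_2,x_3,x_4$.  \n",
    "Suppose the true labels fall into three classes, cat, dog and chicken, encoded as the discrete values $y_1=1,y_2=2,y_3=3$.  \n",
    "- Weight vectors\n",
    "$$o_1=x_1w_{11}+x_2w_{21}+x_3w_{31}+x_4w_{41}+b_1$$\n",
    "$$o_2=x_1w_{12}+x_2w_{22}+x_3w_{32}+x_4w_{42}+b_2$$\n",
    "$$o_3=x_1w_{13}+x_2w_{23}+x_3w_{33}+x_4w_{43}+b_3$$\n",
    "- Neural network diagram  \n",
    "Like linear regression, softmax regression is a single-layer neural network. Because each output $o_1,o_2,o_3$ depends on all of the inputs $x_1,x_2,x_3,x_4$, the output layer of softmax regression is also a fully connected layer.  \n",
    "![soft max neural network](https://staticcdn.boyuai.com/materials/2020/02/15/HOfTeZu5l4-dFL5XRA9y1.png!png)\n",
    "Since a classification problem needs a discrete prediction, a simple approach is to treat the output value $o_i$ as the confidence that the predicted class is $i$, and to take the class with the largest output as the prediction, i.e. to output $\\underset{i}{\\arg\\max}\\,o_i$.  \n",
    "For example, if $o_1,o_2,o_3$ are $0.1,10,0.1$, then $o_2$ is the largest, so the predicted class is 2, which represents dog.  \n",
    "- The problem with raw outputs  \n",
    "Using the values of the output layer directly has two problems:  \n",
    "    1. The range of the output values is not fixed, so it is hard to judge intuitively what these values mean.\n",
    "    2. The true labels are discrete values, so the error between them and outputs of unbounded range is hard to measure.  \n",
    "The softmax operator solves both problems. It transforms the output values into a probability distribution whose values are positive and sum to 1:\n",
    "$$\\hat{y}_1,\\hat{y}_2,\\hat{y}_3=\\text{softmax}(o_1,o_2,o_3)$$\n",
    "where\n",
    "$$\\hat{y}_1=\\frac{\\exp(o_1)}{\\sum_{i=1}^3\\exp(o_i)}$$\n",
    "$$\\hat{y}_2=\\frac{\\exp(o_2)}{\\sum_{i=1}^3\\exp(o_i)}$$\n",
    "$$\\hat{y}_3=\\frac{\\exp(o_3)}{\\sum_{i=1}^3\\exp(o_i)}$$\n",
    "It is easy to see that $\\hat{y}_1+\\hat{y}_2+\\hat{y}_3=1$ and $0\\leq\\hat{y}_1,\\hat{y}_2,\\hat{y}_3\\leq1$, so $\\hat{y}_1,\\hat{y}_2,\\hat{y}_3$ is a valid probability distribution. Moreover,\n",
    "$$\\underset{i}{\\arg\\max}\\,o_i=\\underset{i}{\\arg\\max}\\,\\hat{y}_i$$\n",
    "so the softmax operation does not change the predicted class.  \n",
    "- Computational efficiency\n",
    "    - Vectorized expression for a single sample\n",
    "    $$W=\n",
    "    \\begin{bmatrix}\n",
    "     w_{11} & w_{12} & w_{13}\\\\\n",
    "     w_{21} & w_{22} & w_{23}\\\\\n",
    "     w_{31} & w_{32} & w_{33}\\\\\n",
    "     w_{41} & w_{42} & w_{43}\\\\\n",
    "    \\end{bmatrix},\\,\n",
    "     b=\\begin{bmatrix} b_1 & b_2 & b_3 \\end{bmatrix}\n",
    "    $$\n",
    "Let the features of image sample $i$, which is 2 pixels high and 2 pixels wide, be\n",
    "$$x^{(i)}=\\begin{bmatrix} x^{(i)}_1 & x^{(i)}_2 & x^{(i)}_3 & x^{(i)}_4 \\end{bmatrix}$$\n",
    "The output of the output layer is\n",
    "$$o^{(i)}=\\begin{bmatrix} o^{(i)}_1 & o^{(i)}_2 & o^{(i)}_3 \\end{bmatrix}$$\n",
    "The predicted probability distribution over cat, dog and chicken is\n",
    "$$\\hat{y}^{(i)}=\\begin{bmatrix} \\hat{y}^{(i)}_1 & \\hat{y}^{(i)}_2 & \\hat{y}^{(i)}_3 \\end{bmatrix}$$\n",
    "The vectorized expression with which softmax regression classifies sample $i$ is\n",
    "$$o^{(i)}=x^{(i)}W+b$$\n",
    "$$\\hat{y}^{(i)}=\\text{softmax}(o^{(i)})$$\n",
    "    - Vectorized expression for a mini-batch\n",
    "To improve computational efficiency further, we usually perform vectorized computation on a mini-batch of data. In general, consider a mini-batch of samples with batch size $n$, number of inputs (features) $d$ and number of outputs (classes) $q$. Let the batch features be $X\\in \\mathbb{R}^{n\\times d}$, and suppose the weight and bias parameters of softmax regression are $W\\in \\mathbb{R}^{d\\times q}$ and $b\\in \\mathbb{R}^{1\\times q}$. The vectorized expression for softmax regression is\n",
    "$$O=XW+b$$\n",
    "$$\\hat{Y}=\\text{softmax}(O)$$\n",
    "where the addition uses broadcasting, $O,\\hat{Y}\\in \\mathbb{R}^{n\\times q}$, and the $i$-th rows of these two matrices are the output $o^{(i)}$ and the probability distribution $\\hat{y}^{(i)}$ of sample $i$, respectively."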
   ]
  },
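  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked illustration of the softmax formula (a sketch with natural exponentials and rounded values), take the example outputs $o_1,o_2,o_3=0.1,10,0.1$ used earlier:\n",
    "$$\\exp(0.1)\\approx1.105,\\quad \\exp(10)\\approx22026.466,\\quad \\sum_{i=1}^3\\exp(o_i)\\approx22028.676$$\n",
    "$$\\hat{y}_1=\\hat{y}_3\\approx\\frac{1.105}{22028.676}\\approx0.00005,\\quad \\hat{y}_2\\approx\\frac{22026.466}{22028.676}\\approx0.9999$$\n",
    "The three values are positive, sum to 1 up to rounding, and the largest one still belongs to class 2, so the predicted class is unchanged."
   ]
  },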
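  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cross-entropy loss function\n",
    "For sample $i$, we construct a vector $y^{(i)}\\in \\mathbb{R}^{q}$ whose $y^{(i)}$-th element (where the scalar $y^{(i)}$ is the discrete class value of sample $i$) is 1 and whose other elements are 0. The training objective can then be stated as making the predicted probability distribution $\\hat{y}^{(i)}$ as close as possible to the true label distribution $y^{(i)}$.\n",
    "- Squared loss\n",
    "$$Loss=\\|\\hat{y}^{(i)}-y^{(i)}\\|^2/2$$\n",
    "- Cross entropy is better suited to measuring the difference between two probability distributions\n",
    "$$H(y^{(i)},\\hat{y}^{(i)})=-\\sum_{j=1}^q y^{(i)}_j\\log\\hat{y}^{(i)}_{j}$$\n",
    "Here $y^{(i)}_j$ is an element of the vector $y^{(i)}$ and is either 0 or 1; take care to distinguish it from the discrete class value of sample $i$, which is also written $y^{(i)}$. In the expression above, only the $y^{(i)}$-th element $y^{(i)}_{y^{(i)}}$ of the vector $y^{(i)}$ is 1 and all others are 0, so $H(y^{(i)},\\hat{y}^{(i)})=-\\log\\hat{y}^{(i)}_{y^{(i)}}$. In other words, cross entropy only cares about the predicted probability of the correct class: as long as that value is large enough, the classification is guaranteed to be correct. When a sample has several labels, for example when an image contains more than one object, this simplification does not apply. Even in that case, cross entropy still only cares about the predicted probabilities of the object classes that appear in the image.  \n",
    "Suppose the training set has $n$ samples. The cross-entropy loss function is defined as\n",
    "$$\\ell(\\Theta)=\\frac{1}{n}\\sum_{i=1}^nH(y^{(i)},\\hat{y}^{(i)})$$"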
   ]
  },
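  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small worked example may help here. The numbers are assumptions chosen for illustration, with $q=3$ and natural logarithms. Suppose the true class of a sample is 2, so its label vector is $(0,1,0)$, and the model predicts $\\hat{y}=(0.2,0.7,0.1)$. Only the term for the correct class survives:\n",
    "$$H=-(0\\cdot\\log0.2+1\\cdot\\log0.7+0\\cdot\\log0.1)=-\\log0.7\\approx0.357$$\n",
    "If the model instead assigned only $0.3$ to the correct class, the loss would rise to\n",
    "$$H=-\\log0.3\\approx1.204$$\n",
    "so a lower predicted probability for the correct class gives a larger loss."
   ]
  },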
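  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading the Fashion-MNIST dataset\n",
    "We use torchvision:\n",
    "1. .datasets: functions for loading data and interfaces to common datasets\n",
    "2. .models: common model architectures (including pretrained models), e.g. AlexNet, VGG, ResNet\n",
    "3. .transforms: common image transformations\n",
    "4. .utils: other utilities"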
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import libraries\n",
    "%matplotlib inline\n",
    "from IPython import display\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import time\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"/home/kesci/input\")  # make the course helper package importable\n",
    "import d2lzh1981 as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "mnist_train=torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065',train=True,download=True,transform=transforms.ToTensor())\n",
    "mnist_test=torchvision.datasets.FashionMNIST(root='/home/kesci/input/FashionMNIST2065',train=False,download=True,transform=transforms.ToTensor())"
   ]
  },
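  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- root (string): root directory of the dataset, where the processed/training.pt and processed/test.pt files are stored.\n",
    "- train (bool, optional): if True, the dataset is created from training.pt; otherwise from test.pt.\n",
    "- download (bool, optional): if True, the data is downloaded from the internet and placed in the root directory. If the data already exists there, it is not downloaded again.\n",
    "- transform (callable, optional): a function or transform that takes a PIL image and returns the transformed data, e.g. transforms.RandomCrop.\n",
    "- target_transform (callable, optional): a function or transform that is applied to the target."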
   ]
  },
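  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameters above can be combined with `torch.utils.data.DataLoader` to obtain mini-batch iterators without any course-specific helper package. The cell below is a minimal sketch, not the loader used elsewhere in this notebook: the function name `load_fashion_mnist` and the `./data` root directory are assumptions made for illustration, and running it downloads the dataset if it is not already present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "def load_fashion_mnist(batch_size, root='./data'):\n",
    "    # Hypothetical helper: returns (train_loader, test_loader) that yield torch tensors\n",
    "    to_tensor = transforms.ToTensor()  # PIL image -> float tensor of shape (1, 28, 28) with values in [0, 1]\n",
    "    train_set = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=to_tensor)\n",
    "    test_set = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=to_tensor)\n",
    "    # Shuffle only the training data; evaluation order does not matter\n",
    "    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
    "    return train_loader, test_loader\n",
    "\n",
    "train_loader, test_loader = load_fashion_mnist(256)\n",
    "X_batch, y_batch = next(iter(train_loader))\n",
    "print(X_batch.shape, y_batch.shape)  # one batch of images and one batch of integer labels"
   ]
  },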
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Softmax from scratch"
   ]
  },
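  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append(\"/home/kesci/input\")\n",
    "# d2lzh is the MXNet helper package and yields NDArray batches, which have no .view();\n",
    "# d2lzh1981 is assumed here to be the PyTorch counterpart imported earlier in this notebook\n",
    "import d2lzh1981 as d2l"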
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading the training and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE=256\n",
    "train_iter,test_iter=d2l.load_data_fashion_mnist(BATCH_SIZE,root='/home/kesci/input/FashionMNIST2065')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initializing the model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_INPUTS= 784\n",
    "NUM_OUTPUTS=10\n",
    "\n",
    "W=torch.tensor(np.random.normal(0,0.01,(NUM_INPUTS,NUM_OUTPUTS)),dtype=torch.float)\n",
    "b=torch.zeros(NUM_OUTPUTS,dtype=torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W.requires_grad_(requires_grad=True)\n",
    "b.requires_grad_(requires_grad=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Operating on a multi-dimensional Tensor along a dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1)a tensor([[5, 7, 9]])\n",
      "2)a tensor([[ 6],\n",
      "        [15]])\n",
      "1)b tensor([5, 7, 9])\n",
      "2)b tensor([ 6, 15])\n"
     ]
    }
   ],
   "source": [
    "X = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
    "print(\"1)a\", X.sum(dim=0, keepdim=True))   # dim=0: sum down each column; keepdim=True keeps a 2-D result of shape (1, 3)\n",
    "print(\"2)a\", X.sum(dim=1, keepdim=True))   # dim=1: sum across each row; keepdim=True keeps a 2-D result of shape (2, 1)\n",
    "print(\"1)b\", X.sum(dim=0, keepdim=False))  # dim=0: sum down each column; the reduced dimension is dropped, shape (3,)\n",
    "print(\"2)b\", X.sum(dim=1, keepdim=False))  # dim=1: sum across each row; the reduced dimension is dropped, shape (2,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining the softmax operation\n",
    "$$\n",
    "\\hat{y}_{j}=\\frac{\\exp(o_{j})}{\\sum_{i=1}^q\\exp(o_{i})}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(X):\n",
    "    X_exp = X.exp()  # element-wise exponential\n",
    "    partition = X_exp.sum(dim=1, keepdim=True)  # per-row normalizer, shape (n, 1)\n",
    "    # print(\"X size is \", X_exp.size())\n",
    "    # print(\"partition size is \", partition, partition.size())\n",
    "    return X_exp / partition  # broadcasting divides each row by its own sum"
   ]
  },
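  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One practical caveat about the implementation above: `X.exp()` overflows to `inf` in float32 once an entry exceeds roughly 88, and the division then yields `nan`. A common remedy, which is a swapped-in variant rather than part of this notebook's model, is to subtract each row's maximum before exponentiating; softmax is unchanged by such a per-row shift. The cell below is a sketch under that assumption, and the name `stable_softmax` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def stable_softmax(X):\n",
    "    # Shift every row so that its largest entry becomes 0; exp() then stays within (0, 1]\n",
    "    X_shifted = X - X.max(dim=1, keepdim=True)[0]\n",
    "    X_exp = X_shifted.exp()\n",
    "    return X_exp / X_exp.sum(dim=1, keepdim=True)\n",
    "\n",
    "X_big = torch.tensor([[1000.0, 1001.0, 1002.0]])  # large logits that would overflow a plain exp()\n",
    "probs = stable_softmax(X_big)\n",
    "print(probs, probs.sum(dim=1))\n",
    "print(torch.allclose(probs, torch.softmax(X_big, dim=1)))  # compare with PyTorch's built-in softmax"
   ]
  },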
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2059, 0.1872, 0.3097, 0.1349, 0.1624],\n",
      "        [0.1714, 0.1765, 0.1789, 0.2168, 0.2563]]) \n",
      " tensor([1.0000, 1.0000])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand((2, 5))\n",
    "X_prob = softmax(X)\n",
    "print(X_prob, '\\n', X_prob.sum(dim=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The softmax regression model\n",
    "$$\n",
    "\\begin{aligned}\n",
    "{\\bf{o}}^{(i)}&={\\bf{x}}^{(i)}{\\bf{W}}+{\\bf{b}},\\\\\n",
    "{\\bf{\\hat{y}}}^{(i)}&=\\text{softmax}({\\bf{o}}^{(i)}).\n",
    "\\end{aligned}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmaxNet(X):\n",
    "    # Flatten each image into a row of NUM_INPUTS pixels, apply the linear layer, then softmax\n",
    "    return softmax(torch.mm(X.view((-1, NUM_INPUTS)), W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining the loss function\n",
    "$$\n",
    "\\begin{aligned}\n",
    "H({\\bf{y}}^{(i)},{\\bf{\\hat{y}}}^{(i)})&=-\\sum_{j=1}^q y^{(i)}_{j}\\log\\hat{y}^{(i)}_{j},\\\\\n",
    "\\ell({\\bf{\\Theta}})&=\\frac{1}{n}\\sum_{i=1}^nH({\\bf{y}}^{(i)},{\\bf{\\hat{y}}}^{(i)}),\\\\\n",
    "\\ell({\\bf{\\Theta}})&=-\\frac{1}{n}\\sum_{i=1}^n\\log\\hat{y}^{(i)}_{y^{(i)}}\n",
    "\\end{aligned}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Parameters of `torch.gather`:\n",
    "- dim (python:int) – the axis along which to index\n",
    "- index (LongTensor) – the indices of elements to gather"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1000],\n",
       "        [0.5000]])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])\n",
    "y=torch.LongTensor([0,2])\n",
    "y.size()\n",
    "# y.view(-1)\n",
    "y_hat.gather(1,y.view(-1,1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cross_entropy(y_hat,y):\n",
    "    return -torch.log(y_hat.gather(1,y.view(-1,1)))"
   ]
  },
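  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the `gather`-based loss, the cell below recomputes the mean cross entropy for the small `y_hat` and `y` example used earlier and compares it with `torch.nn.functional.nll_loss`, which takes log-probabilities and integer class labels and averages by default. This is an illustrative sketch; the variable names `manual` and `builtin` are assumptions and do not appear elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])  # predicted probabilities for 2 samples\n",
    "y = torch.tensor([0, 2])                                  # true class indices\n",
    "\n",
    "# Pick the predicted probability of the true class per row, then average the negative logs\n",
    "manual = -torch.log(y_hat.gather(1, y.view(-1, 1))).mean()\n",
    "# nll_loss expects log-probabilities; its default reduction is the mean over the batch\n",
    "builtin = F.nll_loss(torch.log(y_hat), y)\n",
    "print(manual.item(), builtin.item())  # both are about 1.4979, i.e. (-log 0.1 - log 0.5) / 2"
   ]
  },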
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(y_hat,y):\n",
    "    return (y_hat.argmax(dim=1)==y).float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5\n"
     ]
    }
   ],
   "source": [
    "print(accuracy(y_hat, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iter, net):\n",
    "    acc_sum, n = 0.0, 0\n",
    "    for X, y in data_iter:\n",
    "        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()\n",
    "        n += y.shape[0]\n",
    "    return acc_sum / n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(evaluate_accuracy(test_iter,softmaxNet))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs, lr = 5, 0.1\n",
    "batch_size = 256\n",
    "def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):\n",
    "    for epoch in range(num_epochs):\n",
    "        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0\n",
    "        for X, y in train_iter:\n",
    "            y_hat = net(X)\n",
    "            l = loss(y_hat, y).sum()\n",
    "            # Clear the gradients left over from the previous step\n",
    "            if optimizer is not None:\n",
    "                optimizer.zero_grad()\n",
    "            elif params is not None and params[0].grad is not None:\n",
    "                for param in params:\n",
    "                    param.grad.data.zero_()\n",
    "            # Backpropagate, then update with hand-written SGD unless an optimizer is given\n",
    "            l.backward()\n",
    "            if optimizer is None:\n",
    "                d2l.sgd(params, lr, batch_size)\n",
    "            else:\n",
    "                optimizer.step()\n",
    "            train_l_sum += l.item()\n",
    "            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()\n",
    "            n += y.shape[0]\n",
    "        test_acc = evaluate_accuracy(test_iter, net)\n",
    "        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'\n",
    "              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))\n",
    "train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Python 3 iterators have no .next() method (it raises AttributeError); use the built-in next()\n",
    "X, y = next(iter(test_iter))\n",
    "\n",
    "true_labels = d2l.get_fashion_mnist_labels(y.numpy())\n",
    "pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())\n",
    "titles = [true + '\\n' + pred for true, pred in zip(true_labels, pred_labels)]\n",
    "\n",
    "d2l.show_fashion_mnist(X[0:9], titles[0:9])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing softmax regression with PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4.0\n"
     ]
    }
   ],
   "source": [
    "# Import the required packages and modules\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import init\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append(\"/home/kesci/input\")\n",
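    "# Assumption: the PyTorch helpers are installed as d2lzh_pytorch.\n",
    "# Plain d2lzh is the MXNet edition; its loaders yield NDArray batches, which have no .view().\n",
    "import d2lzh_pytorch as d2l\n",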
    "\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set the batch size and load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root='/home/kesci/input/FashionMNIST2065')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the network model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_inputs = 784\n",
    "num_outputs = 10\n",
    "\n",
    "class LinearNet(nn.Module):\n",
    "    def __init__(self, num_inputs, num_outputs):\n",
    "        super(LinearNet, self).__init__()\n",
    "        self.linear = nn.Linear(num_inputs, num_outputs)\n",
    "    def forward(self, x):  # x has shape (batch, 1, 28, 28)\n",
    "        y = self.linear(x.view(x.shape[0], -1))\n",
    "        return y\n",
    "    \n",
    "# net = LinearNet(num_inputs, num_outputs)\n",
    "\n",
    "class FlattenLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(FlattenLayer, self).__init__()\n",
    "    def forward(self, x):  # x has shape (batch, *, *, ...)\n",
    "        return x.view(x.shape[0], -1)\n",
    "\n",
    "from collections import OrderedDict\n",
    "net = nn.Sequential(\n",
    "        # FlattenLayer(),\n",
    "        # LinearNet(num_inputs, num_outputs) \n",
    "        OrderedDict([\n",
    "           ('flatten', FlattenLayer()),\n",
    "           ('linear', nn.Linear(num_inputs, num_outputs))])  # our own LinearNet(num_inputs, num_outputs) would also work here\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize the model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "init.normal_(net.linear.weight, mean=0, std=0.01)\n",
    "init.constant_(net.linear.bias, val=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
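    "# nn.CrossEntropyLoss applies log-softmax internally, so the network outputs raw scores (logits)\n",
    "loss = nn.CrossEntropyLoss()  # its signature is shown below\n",
    "# class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')"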
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # its signature is shown below\n",
    "# class torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 5\n",
    "# params and lr are None because the optimizer performs the parameter updates.\n",
    "# Note: d2l must be the PyTorch edition. The MXNet d2lzh feeds NDArray batches to the network,\n",
    "# and FlattenLayer then fails with \"'NDArray' object has no attribute 'view'\".\n",
    "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
